{% extends 'base.attached.class' %}

{% block content %}
namespace {
    /**
     * Ensemble classifier that averages the class probabilities of its trees.
     *
     * The model is a list of tree descriptions; each one is turned into a
     * {{ class_name }}\Tree. At least one tree is assumed, because compute()
     * reads the number of classes from the first tree.
     */
    class {{ class_name }} {

        /**
         * Trees of the ensemble, in model order.
         *
         * Declared explicitly (and kept public, as the former dynamic
         * property was) so PHP 8.2+ raises no dynamic-property deprecation.
         *
         * @var array
         */
        public $forest;

        /**
         * Build one tree per model entry.
         *
         * @param array $model List of arrays with the keys 'lefts', 'rights',
         *                     'thresholds', 'indices' and 'classes'.
         */
        public function __construct($model) {
            $this->forest = array();
            foreach ($model as $tree) {
                $this->forest[] = new {{ class_name }}\Tree(
                    $tree['lefts'], $tree['rights'],
                    $tree['thresholds'], $tree['indices'],
                    $tree['classes']
                );
            }
        }

        /**
         * Return the index of the largest value.
         *
         * The comparison is strict (>), so the first occurrence wins on ties.
         * An empty list yields 0.
         *
         * @param array $nums Zero-indexed list of numbers.
         * @return int
         */
        private function findMax($nums) {
            $idx = 0;
            $l = count($nums);
            // Index 0 is the initial candidate, so scanning starts at 1.
            for ($i = 1; $i < $l; $i++) {
                if ($nums[$i] > $nums[$idx]) {
                    $idx = $i;
                }
            }
            return $idx;
        }


        /**
         * Average the per-tree class distributions for one sample.
         *
         * @param array $features Feature values of the sample.
         * @return array One averaged value per class.
         */
        private function compute($features) {
            $nTrees = count($this->forest);
            // Class count is taken from the first node of the first tree;
            // assumes every node of every tree holds one value per class.
            $nClasses = count($this->forest[0]->classes[0]);
            $probas = array_fill(0, $nClasses, 0);
            foreach ($this->forest as $tree) {
                $temp = $tree->compute($features);
                for ($j = 0; $j < $nClasses; $j++) {
                    $probas[$j] += $temp[$j];
                }
            }
            for ($j = 0; $j < $nClasses; $j++) {
                $probas[$j] /= $nTrees;
            }
            return $probas;
        }

        /**
         * Class probabilities for one sample, averaged over all trees.
         *
         * @param array $features Feature values of the sample.
         * @return array One value per class.
         */
        public function predictProba($features) {
            return $this->compute($features);
        }

        /**
         * Index of the most probable class for one sample.
         *
         * @param array $features Feature values of the sample.
         * @return int
         */
        public function predict($features) {
            return $this->findMax($this->predictProba($features));
        }

    }
}

namespace {{ class_name }} {
    /**
     * A single decision tree stored as parallel arrays indexed by node id.
     *
     * Properties are declared explicitly (and kept public, as the former
     * dynamic properties were) so PHP 8.2+ raises no dynamic-property
     * deprecation; the ensemble class reads $classes directly.
     */
    class Tree {
        /** @var array Node id to follow when the feature is <= the threshold. */
        public $lefts;
        /** @var array Node id to follow when the feature is > the threshold. */
        public $rights;
        /** @var array Split threshold per node; -2 marks a node where traversal stops. */
        public $thresholds;
        /** @var array Index into the feature vector tested at each node. */
        public $indices;
        /** @var array Per-node list of class values, normalised on output. */
        public $classes;

        /**
         * @param array $lefts      Left-child node ids.
         * @param array $rights     Right-child node ids.
         * @param array $thresholds Split thresholds (-2 where traversal stops).
         * @param array $indices    Feature index tested at each node.
         * @param array $classes    Class values per node.
         */
        public function __construct($lefts, $rights, $thresholds, $indices, $classes) {
            $this->lefts = $lefts;
            $this->rights = $rights;
            $this->thresholds = $thresholds;
            $this->indices = $indices;
            $this->classes = $classes;
        }

        /**
         * Scale a list of numbers so that it sums to 1.
         *
         * When the values sum to zero a uniform distribution is returned
         * instead, which avoids dividing by zero.
         *
         * @param array $nums Zero-indexed list of numbers.
         * @return array Zero-indexed list of floats.
         */
        private function normVals($nums) {
            $l = count($nums);
            $result = [];
            $sum = 0.;

            for ($i = 0; $i < $l; $i++) {
                $sum += $nums[$i];
            }
            // $sum is always a float, so it must be compared with a float:
            // a strict comparison against the integer 0 can never be true.
            if ($sum === 0.0) {
                for ($i = 0; $i < $l; $i++) {
                    $result[$i] = 1. / $l;
                }
            } else {
                for ($i = 0; $i < $l; $i++) {
                    $result[$i] = $nums[$i] / $sum;
                }
            }
            return $result;
        }

        /**
         * Walk the tree for one sample and return the normalised class
         * values of the node where traversal stops.
         *
         * @param array $features Feature values of the sample.
         * @param int   $node     Node id to start from (defaults to the root, 0).
         * @return array One value per class, summing to 1.
         */
        public function compute($features, $node = 0) {
            // NOTE(review): traversal stops where the threshold equals -2, so a
            // genuine split threshold of exactly -2 would also stop it — confirm
            // the model generator never emits one.
            while ($this->thresholds[$node] != -2) {
                if ($features[$this->indices[$node]] <= $this->thresholds[$node]) {
                    $node = $this->lefts[$node];
                } else {
                    $node = $this->rights[$node];
                }
            }
            return $this->normVals($this->classes[$node]);
        }
    }
}

namespace {
  // Command-line entry point: runs only when at least one argument follows
  // the script name.
  if ($argc > 1) {

      // Features: drop the script name; the remaining arguments are the
      // feature values (still strings at this point).
      array_shift($argv);
      $features = $argv;

      // Model data (the rendered template is expected to define $model here):
      {{ model }}

      // Estimator:
      $clf = new {{ class_name }}($model);

      {% if is_test or to_json %}
      // Emit the predicted class and the class probabilities as one JSON object:
      $prediction = $clf->predict($features);
      $probabilities = $clf->predictProba($features);
      $payload = array("predict" => $prediction, "predict_proba" => $probabilities);
      fwrite(STDOUT, json_encode($payload));
      {% else %}
      // Print the predicted class:
      $prediction = $clf->predict($features);
      $predictionLine = "Predicted class: #" . strval($prediction) . "\n";
      fwrite(STDOUT, $predictionLine);

      // Print one line per class probability:
      $probabilities = $clf->predictProba($features);
      $nProbabilities = count($probabilities);
      for ($i = 0; $i < $nProbabilities; $i++) {
          $probabilityLine = "Probability of class #" . strval($i) . " = " . strval($probabilities[$i]) . "\n";
          fwrite(STDOUT, $probabilityLine);
      }
      {% endif %}

  }
}
{% endblock %}